{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# SDGym Benchmark" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 10 | Train Loss -0.649\n", "Epoch 20 | Train Loss -0.688\n", "Epoch 30 | Train Loss -0.700\n", "Epoch 40 | Train Loss -0.707\n", "Epoch 50 | Train Loss -0.709\n", "Epoch 60 | Train Loss -0.731\n", "Epoch 70 | Train Loss -0.727\n", "Epoch 80 | Train Loss -0.728\n", "Epoch 90 | Train Loss -0.720\n", "Epoch 100 | Train Loss -0.731\n", "Epoch 10 | Train Loss -0.622\n", "Epoch 20 | Train Loss -0.677\n", "Epoch 30 | Train Loss -0.697\n", "Epoch 40 | Train Loss -0.715\n", "Epoch 50 | Train Loss -0.718\n", "Epoch 60 | Train Loss -0.715\n", "Epoch 70 | Train Loss -0.719\n", "Epoch 80 | Train Loss -0.711\n", "Epoch 90 | Train Loss -0.714\n", "Epoch 100 | Train Loss -0.717\n", "Epoch 10 | Train Loss -0.647\n", "Epoch 20 | Train Loss -0.687\n", "Epoch 30 | Train Loss -0.701\n", "Epoch 40 | Train Loss -0.711\n", "Epoch 50 | Train Loss -0.707\n", "Epoch 60 | Train Loss -0.728\n", "Epoch 70 | Train Loss -0.720\n", "Epoch 80 | Train Loss -0.732\n", "Epoch 90 | Train Loss -0.726\n", "Epoch 100 | Train Loss -0.724\n", "Epoch 10 | Train Loss 0.058\n", "Epoch 20 | Train Loss 0.023\n", "Epoch 30 | Train Loss -0.052\n", "Epoch 40 | Train Loss -0.045\n", "Epoch 50 | Train Loss -0.085\n", "Epoch 60 | Train Loss -0.110\n", "Epoch 70 | Train Loss -0.086\n", "Epoch 80 | Train Loss -0.104\n", "Epoch 90 | Train Loss -0.127\n", "Epoch 100 | Train Loss -0.134\n", "Epoch 10 | Train Loss 0.043\n", "Epoch 20 | Train Loss -0.001\n", "Epoch 30 | Train Loss -0.035\n", "Epoch 40 | Train Loss -0.044\n", "Epoch 50 | Train Loss -0.036\n", "Epoch 60 | Train Loss -0.051\n", "Epoch 70 | Train Loss -0.143\n", "Epoch 80 | Train Loss -0.118\n", "Epoch 90 | Train Loss -0.130\n", "Epoch 100 | Train Loss -0.113\n", "Epoch 10 | Train Loss 0.037\n", "Epoch 20 | Train Loss -0.038\n", "Epoch 
30 | Train Loss -0.093\n", "Epoch 40 | Train Loss -0.075\n", "Epoch 50 | Train Loss -0.152\n", "Epoch 60 | Train Loss -0.146\n", "Epoch 70 | Train Loss -0.146\n", "Epoch 80 | Train Loss -0.173\n", "Epoch 90 | Train Loss -0.146\n", "Epoch 100 | Train Loss -0.137\n", "Epoch 10 | Train Loss -0.077\n", "Epoch 20 | Train Loss -0.087\n", "Epoch 30 | Train Loss -0.121\n", "Epoch 40 | Train Loss -0.153\n", "Epoch 50 | Train Loss -0.204\n", "Epoch 60 | Train Loss -0.207\n", "Epoch 70 | Train Loss -0.245\n", "Epoch 80 | Train Loss -0.240\n", "Epoch 90 | Train Loss -0.240\n", "Epoch 100 | Train Loss -0.224\n", "Epoch 10 | Train Loss -0.111\n", "Epoch 20 | Train Loss -0.133\n", "Epoch 30 | Train Loss -0.143\n", "Epoch 40 | Train Loss -0.200\n", "Epoch 50 | Train Loss -0.214\n", "Epoch 60 | Train Loss -0.187\n", "Epoch 70 | Train Loss -0.205\n", "Epoch 80 | Train Loss -0.217\n", "Epoch 90 | Train Loss -0.222\n", "Epoch 100 | Train Loss -0.189\n", "Epoch 10 | Train Loss -0.106\n", "Epoch 20 | Train Loss -0.138\n", "Epoch 30 | Train Loss -0.126\n", "Epoch 40 | Train Loss -0.199\n", "Epoch 50 | Train Loss -0.214\n", "Epoch 60 | Train Loss -0.203\n", "Epoch 70 | Train Loss -0.222\n", "Epoch 80 | Train Loss -0.161\n", "Epoch 90 | Train Loss -0.223\n", "Epoch 100 | Train Loss -0.240\n", "Epoch 10 | Train Loss -0.083\n", "Epoch 20 | Train Loss -0.094\n", "Epoch 30 | Train Loss -0.105\n", "Epoch 40 | Train Loss -0.117\n", "Epoch 50 | Train Loss -0.126\n", "Epoch 60 | Train Loss -0.129\n", "Epoch 70 | Train Loss -0.118\n", "Epoch 80 | Train Loss -0.143\n", "Epoch 90 | Train Loss -0.143\n", "Epoch 100 | Train Loss -0.145\n", "Epoch 10 | Train Loss -0.070\n", "Epoch 20 | Train Loss -0.106\n", "Epoch 30 | Train Loss -0.100\n", "Epoch 40 | Train Loss -0.122\n", "Epoch 50 | Train Loss -0.125\n", "Epoch 60 | Train Loss -0.135\n", "Epoch 70 | Train Loss -0.138\n", "Epoch 80 | Train Loss -0.124\n", "Epoch 90 | Train Loss -0.142\n", "Epoch 100 | Train Loss -0.144\n", "Epoch 10 | Train Loss 
-0.077\n", "Epoch 20 | Train Loss -0.100\n", "Epoch 30 | Train Loss -0.116\n", "Epoch 40 | Train Loss -0.128\n", "Epoch 50 | Train Loss -0.121\n", "Epoch 60 | Train Loss -0.146\n", "Epoch 70 | Train Loss -0.147\n", "Epoch 80 | Train Loss -0.150\n", "Epoch 90 | Train Loss -0.147\n", "Epoch 100 | Train Loss -0.148\n", "Epoch 10 | Train Loss 0.017\n", "Epoch 20 | Train Loss -0.029\n", "Epoch 30 | Train Loss -0.040\n", "Epoch 40 | Train Loss -0.033\n", "Epoch 50 | Train Loss -0.072\n", "Epoch 60 | Train Loss -0.064\n", "Epoch 70 | Train Loss -0.084\n", "Epoch 80 | Train Loss -0.067\n", "Epoch 90 | Train Loss -0.082\n", "Epoch 100 | Train Loss -0.076\n", "Epoch 10 | Train Loss -0.003\n", "Epoch 20 | Train Loss -0.022\n", "Epoch 30 | Train Loss -0.053\n", "Epoch 40 | Train Loss -0.076\n", "Epoch 50 | Train Loss -0.049\n", "Epoch 60 | Train Loss -0.090\n", "Epoch 70 | Train Loss -0.095\n", "Epoch 80 | Train Loss -0.105\n", "Epoch 90 | Train Loss -0.107\n", "Epoch 100 | Train Loss -0.120\n", "Epoch 10 | Train Loss 0.011\n", "Epoch 20 | Train Loss -0.022\n", "Epoch 30 | Train Loss -0.049\n", "Epoch 40 | Train Loss -0.055\n", "Epoch 50 | Train Loss -0.059\n", "Epoch 60 | Train Loss -0.067\n", "Epoch 70 | Train Loss -0.084\n", "Epoch 80 | Train Loss -0.095\n", "Epoch 90 | Train Loss -0.098\n", "Epoch 100 | Train Loss -0.090\n", "Epoch 10 | Train Loss -0.102\n", "Epoch 20 | Train Loss -0.129\n", "Epoch 30 | Train Loss -0.138\n", "Epoch 40 | Train Loss -0.148\n", "Epoch 50 | Train Loss -0.171\n", "Epoch 60 | Train Loss -0.144\n", "Epoch 70 | Train Loss -0.201\n", "Epoch 80 | Train Loss -0.182\n", "Epoch 90 | Train Loss -0.216\n", "Epoch 100 | Train Loss -0.252\n", "Epoch 10 | Train Loss -0.057\n", "Epoch 20 | Train Loss -0.089\n", "Epoch 30 | Train Loss -0.115\n", "Epoch 40 | Train Loss -0.120\n", "Epoch 50 | Train Loss -0.140\n", "Epoch 60 | Train Loss -0.174\n", "Epoch 70 | Train Loss -0.151\n", "Epoch 80 | Train Loss -0.199\n", "Epoch 90 | Train Loss -0.199\n", "Epoch 100 | 
Train Loss -0.201\n", "Epoch 10 | Train Loss -0.075\n", "Epoch 20 | Train Loss -0.092\n", "Epoch 30 | Train Loss -0.110\n", "Epoch 40 | Train Loss -0.118\n", "Epoch 50 | Train Loss -0.166\n", "Epoch 60 | Train Loss -0.163\n", "Epoch 70 | Train Loss -0.140\n", "Epoch 80 | Train Loss -0.159\n", "Epoch 90 | Train Loss -0.161\n", "Epoch 100 | Train Loss -0.158\n" ] } ], "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sdgym\n",
    "from echoflow import EchoFlow\n",
    "\n",
    "# Benchmark configuration.\n",
    "NB_EPOCHS = 100\n",
    "DATASETS = ['ring', 'grid', 'gridr']\n",
    "ITERATIONS = 3\n",
    "# NOTE(review): no random seed is set, so the scores below vary between runs.\n",
    "\n",
    "\n",
    "def _echoflow_synthesize(real_data, categorical_columns, ordinal_columns,\n",
    "                         use_kde=False, nb_epochs=NB_EPOCHS):\n",
    "    \"\"\"Fit EchoFlow on ``real_data`` and sample the same number of rows back.\n",
    "\n",
    "    Categorical and ordinal columns are handed to EchoFlow as strings of\n",
    "    their integer values and converted back to ``int`` after sampling.\n",
    "\n",
    "    Args:\n",
    "        real_data: table passed in by sdgym; wrapped in a DataFrame.\n",
    "        categorical_columns: labels of the categorical columns (presumably\n",
    "            integer positions, as sdgym passes them).\n",
    "        ordinal_columns: labels of the ordinal columns.\n",
    "        use_kde: when true, ``use_kde=True`` is passed to EchoFlow.\n",
    "        nb_epochs: number of training epochs passed to EchoFlow.\n",
    "\n",
    "    Returns:\n",
    "        ``numpy.ndarray`` (``.values``) of the sampled rows.\n",
    "    \"\"\"\n",
    "    # list() makes `+` concatenate the two collections; on numpy arrays it\n",
    "    # would add them element-wise instead.\n",
    "    discrete_columns = list(categorical_columns) + list(ordinal_columns)\n",
    "\n",
    "    real_df = pd.DataFrame(real_data)\n",
    "    for col in discrete_columns:\n",
    "        real_df[col] = real_df[col].astype(int).astype(str)\n",
    "\n",
    "    model_kwargs = {'nb_epochs': nb_epochs}\n",
    "    if use_kde:\n",
    "        # Only pass the flag when requested so the plain variant keeps\n",
    "        # EchoFlow's own default, exactly as before.\n",
    "        model_kwargs['use_kde'] = True\n",
    "    model = EchoFlow(**model_kwargs)\n",
    "    model.fit(real_df)\n",
    "    synthetic_df = model.sample(num_samples=len(real_df))\n",
    "\n",
    "    for col in discrete_columns:\n",
    "        synthetic_df[col] = synthetic_df[col].astype(int)\n",
    "    return synthetic_df.values\n",
    "\n",
    "\n",
    "# The scores table is indexed by these function names (see the results cell),\n",
    "# so keep them unchanged, including the existing `SynthesizeKDE` spelling.\n",
    "def EchoFlowSynthesizer(real_data, categorical_columns, ordinal_columns):\n",
    "    \"\"\"sdgym synthesizer: EchoFlow with its default settings.\"\"\"\n",
    "    return _echoflow_synthesize(real_data, categorical_columns, ordinal_columns)\n",
    "\n",
    "\n",
    "def EchoFlowSynthesizeKDE(real_data, categorical_columns, ordinal_columns):\n",
    "    \"\"\"sdgym synthesizer: EchoFlow with ``use_kde=True``.\"\"\"\n",
    "    return _echoflow_synthesize(real_data, categorical_columns, ordinal_columns,\n",
    "                                use_kde=True)\n",
    "\n",
    "\n",
    "scores = sdgym.run(\n",
    "    synthesizers=[EchoFlowSynthesizer, EchoFlowSynthesizeKDE],\n",
    "    datasets=DATASETS,\n",
    "    iterations=ITERATIONS,\n",
    ")"
   ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
grid/syn_likelihoodgrid/test_likelihoodgridr/syn_likelihoodgridr/test_likelihoodring/syn_likelihoodring/test_likelihoodtimestamp
CTGAN-8.760635-5.062972-8.309750-5.048310-6.591324-2.6652812020-10-17 09:46:54.494331
EchoFlowSynthesizer-6.712230-4.437056-6.496902-4.475942-1.932969-1.7968322020-12-30 23:10:22.816115
EchoFlowSynthesizeKDE-5.402527-4.063265-5.531003-4.154107-2.277480-1.8423712020-12-30 23:10:22.816115
\n", "
" ], "text/plain": [ " grid/syn_likelihood grid/test_likelihood \\\n", "CTGAN -8.760635 -5.062972 \n", "EchoFlowSynthesizer -6.712230 -4.437056 \n", "EchoFlowSynthesizeKDE -5.402527 -4.063265 \n", "\n", " gridr/syn_likelihood gridr/test_likelihood \\\n", "CTGAN -8.309750 -5.048310 \n", "EchoFlowSynthesizer -6.496902 -4.475942 \n", "EchoFlowSynthesizeKDE -5.531003 -4.154107 \n", "\n", " ring/syn_likelihood ring/test_likelihood \\\n", "CTGAN -6.591324 -2.665281 \n", "EchoFlowSynthesizer -1.932969 -1.796832 \n", "EchoFlowSynthesizeKDE -2.277480 -1.842371 \n", "\n", " timestamp \n", "CTGAN 2020-10-17 09:46:54.494331 \n", "EchoFlowSynthesizer 2020-12-30 23:10:22.816115 \n", "EchoFlowSynthesizeKDE 2020-12-30 23:10:22.816115 " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scores.loc[[\"CTGAN\", \"EchoFlowSynthesizer\", \"EchoFlowSynthesizeKDE\"]]" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.8" } }, "nbformat": 4, "nbformat_minor": 4 }